package edu.uky.ai.ml.nn;

import java.util.ArrayList;
import java.util.Random;

/**
 * A neural network is a computational model, based roughly on biological
 * networks of neurons, which can be trained to perform well on various
 * kinds of classification and non-linear regression tasks.  This particular
 * implementation is a feed-forward network with 0 or more hidden layers.
 * 
 * @author Stephen G. Ware
 */
public class NeuralNetwork {

	/** The input neurons */
	public final InputLayer input;
	
	/** The output neurons */
	public final Layer output;
	
	/** All the weighted directed edges present in the network */
	public final Iterable<Edge> edges = new ArrayList<>();
	
	/**
	 * Constructs a new neural network of a given size.  The layers are
	 * chained together in order: the input layer, then each hidden layer,
	 * then the output layer, each one built on top of the previous one.
	 * 
	 * @param input the number of input neurons
	 * @param output the number of output neurons
	 * @param hiddenLayers the number of hidden layers
	 * @param hiddenNodes the number of neurons per hidden layer
	 * @param random a random number generator for setting the initial edge weights
	 */
	public NeuralNetwork(int input, int output, int hiddenLayers, int hiddenNodes, Random random) {
		this.input = new InputLayer(this, input);
		// Each new layer is connected to the layer created just before it.
		Layer previous = this.input;
		for(int i=0; i<hiddenLayers; i++)
			previous = new Layer(this, previous, hiddenNodes, random);
		this.output = new Layer(this, previous, output, random);
	}
	
	/**
	 * Sets the input neurons to the given values.
	 * 
	 * @param values the values, one per input neuron
	 * @throws IllegalArgumentException if the number of values does not match
	 * the number of input neurons
	 */
	public void setInput(double[] values) {
		if(values.length != input.neurons.length)
			throw new IllegalArgumentException("Expected " + input.neurons.length + " input values; " + values.length + " given.");
		for(int i=0; i<values.length; i++)
			input.neurons[i].setValue(values[i]);
	}
	
	/**
	 * Calculates and returns the values of the output neurons given the
	 * current state of the network.
	 * 
	 * @return the values of all output neurons, in a newly allocated array
	 */
	public double[] getOutput() {
		double[] output = new double[this.output.neurons.length];
		for(int i=0; i<output.length; i++)
			output[i] = this.output.neurons[i].getValue();
		return output;
	}
	
	/**
	 * Returns the total error of the network for a given training database.
	 * The total error is simply the sum of all errors on the training
	 * examples.  Note that this method changes the values of the input
	 * neurons, leaving them set to the input of the last example visited.
	 * 
	 * @param database the training database
	 * @return the total error
	 */
	public double getError(Database database) {
		double total = 0;
		for(Example example : database) {
			setInput(example.input);
			double[] output = getOutput();
			total += Error.evaluate(example.output, output);
		}
		return total;
	}
	
	/**
	 * Given a database of examples and a set of values for the input neurons,
	 * this method calculates the output of the network and returns the class
	 * label of the training example whose output is most similar to the
	 * network's output.  Similarity is measured with
	 * {@code Error.evaluate(double[], double[])}; the example with the
	 * strictly lowest error wins, so ties go to the example encountered
	 * first.  Note that this method changes the values of the input neurons.
	 * 
	 * @param database the training database
	 * @param input the values for the input neurons
	 * @return the class label of the example most similar to the network
	 * output, or null if no example has an error below positive infinity
	 * (for instance, when the database has no examples)
	 * @throws IllegalArgumentException if the number of input values does not
	 * match the number of input neurons
	 */
	public String classify(Database database, double[] input) {
		// Feed the input through the network.
		setInput(input);
		double[] output = getOutput();
		// The class label that will eventually be returned.
		String bestLabel = null;
		// How closely the best class label found so far matches the
		// network's output (lower is better).
		double lowestError = Double.POSITIVE_INFINITY;
		for(Example example : database) {
			double error = Error.evaluate(example.output, output);
			if(error < lowestError) {
				// NOTE(review): assumes Example exposes its class label as a
				// field named 'label' -- confirm against Example.
				bestLabel = example.label;
				lowestError = error;
			}
		}
		return bestLabel;
	}
	
	/**
	 * Quantifies how well this network performs on a given database of
	 * training examples.  Accuracy is simply the number of correctly
	 * classified examples divided by the total number of examples.  Note
	 * that this method changes the values of the input neurons.
	 * 
	 * @param database the database of examples
	 * @return the accuracy, or 0 if the database has no examples
	 */
	public double getAccuracy(Database database) {
		// Number of examples classified correctly.  Kept as a double so that
		// the final division is not an integer division.
		double correct = 0;
		// The examples are counted while iterating because only iteration
		// over the database is relied on here.
		int total = 0;
		for(Example example : database) {
			total++;
			String predicted = classify(database, example.input);
			// NOTE(review): assumes Example exposes its class label as a
			// field named 'label' -- confirm against Example.
			if(predicted != null && predicted.equals(example.label))
				correct++;
		}
		// Avoid dividing by zero when there is nothing to classify.
		if(total == 0)
			return 0;
		return correct / total;
	}
}
